# '语义分割数据增强'
from osgeo import gdal
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
from PIL import Image
import random
from tqdm import tqdm
# # 受postgis影响，重新设置环境变量
# os.environ['CPL_ZIP_ENCODING'] = 'UTF-8'
# os.environ['PROJ_LIB'] = r'D:\anaconda3\envs\pytorch_GPU\Lib\site-packages\pyproj\proj_dir\share\proj'
# os.environ['GDAL_DATA'] = r'D:\anaconda3\envs\pytorch_GPU\Library\share'

# ----------------读取栅格数据-----------------
def readRaster(fileName, xoff=0, yoff=0, data_width=0, data_height=0):
    """Read a raster, or a rectangular window of it, with GDAL.

    Args:
        fileName: path of the raster file.
        xoff, yoff: column/row offset of the window's upper-left pixel.
        data_width, data_height: window size in pixels. 0 means "extend from
            the offset to the right/bottom edge of the raster", so the
            defaults read the whole raster.

    Returns:
        (width, height, bands, data, geotrans, proj):
        width/height/bands describe the full raster (RasterXSize,
        RasterYSize, RasterCount), data is what ``Dataset.ReadAsArray``
        returns for the window, geotrans is the dataset's GetGeoTransform()
        value and proj its GetProjection() string.

    Raises:
        OSError: if GDAL cannot open the file. Previously the function only
            printed a message and then failed with AttributeError.
    """
    dataset = gdal.Open(fileName)
    if dataset is None:
        raise OSError(fileName + "文件无法打开")
    width = dataset.RasterXSize    # number of columns
    height = dataset.RasterYSize   # number of rows
    bands = dataset.RasterCount    # number of bands
    # Each size defaults independently and respects the offset. Previously a
    # lone 0 was passed straight to ReadAsArray, and a non-zero offset with
    # both sizes 0 requested a window extending past the raster edge.
    if data_width == 0:
        data_width = width - xoff
    if data_height == 0:
        data_height = height - yoff
    data = dataset.ReadAsArray(xoff, yoff, data_width, data_height)
    geotrans = dataset.GetGeoTransform()  # affine geotransform
    proj = dataset.GetProjection()        # projection definition
    return width, height, bands, data, geotrans, proj
#------------------- 保存栅格-------------
def saveRaster(im_data, im_geotrans, im_proj, path):
    """Write an array to a GeoTIFF file with the given georeferencing.

    Args:
        im_data: 2-D array (height, width) or 3-D array (bands, height, width).
        im_geotrans: geotransform to set on the output dataset.
        im_proj: projection string to set on the output dataset.
        path: output file path (GTiff driver).

    Raises:
        ValueError: if im_data is neither 2-D nor 3-D.
        OSError: if the GTiff driver cannot create the output file.
    """
    im_data = np.asarray(im_data)
    # Exact dtype -> GDAL type mapping. The original substring test
    # ('int16' in name) wrote signed int16 as GDT_UInt16, wrapping negative
    # values, and sent 32-bit integers to Float32, losing precision.
    # int8 keeps the original GDT_Byte mapping. Everything else (floats,
    # bool, int64, ...) stays on GDT_Float32 as before.
    dtype_to_gdal = {
        'uint8': gdal.GDT_Byte,
        'int8': gdal.GDT_Byte,
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,
    }
    datatype = dtype_to_gdal.get(im_data.dtype.name, gdal.GDT_Float32)

    # Normalise to (bands, height, width); a single band gets a leading axis.
    if im_data.ndim == 2:
        im_data = im_data[np.newaxis, ...]
    elif im_data.ndim != 3:
        raise ValueError(
            f"im_data must be 2-D or 3-D, got shape {im_data.shape}")
    im_bands, im_height, im_width = im_data.shape

    driver = gdal.GetDriverByName("GTiff")
    output_tif = driver.Create(path, int(im_width), int(im_height),
                               int(im_bands), datatype)
    if output_tif is None:
        raise OSError(f"GTiff driver could not create {path}")
    output_tif.SetGeoTransform(im_geotrans)  # affine transform parameters
    output_tif.SetProjection(im_proj)        # projection
    for band_index in range(im_bands):
        output_tif.GetRasterBand(band_index + 1).WriteArray(im_data[band_index])
    # Flush cached writes to disk, then drop the reference so GDAL closes it.
    output_tif.FlushCache()
    del output_tif
#------------------将文件名追加回TXT文件中----
def saveTXT(raser_path1, raster_path2, old_file, new_file):
    """Append an augmented sample name to the split list that holds its source.

    ``raser_path1`` is searched first for a line equal to ``old_file``. If
    one is found, ``new_file`` is appended to that file. Otherwise
    ``raster_path2`` is checked the same way. If neither file lists
    ``old_file``, nothing is written.

    Matching is per whole line. The previous substring test treated e.g.
    ``a1`` as present when only ``a10`` was listed.
    """
    for list_path in (raser_path1, raster_path2):
        with open(list_path, "r+") as f:
            content = f.read()
            entries = {line.strip() for line in content.splitlines()}
            if old_file in entries:
                # After read() the file position is at EOF. Make sure the new
                # entry starts on its own line even if the file lacks a
                # trailing newline.
                if content and not content.endswith('\n'):
                    f.write('\n')
                f.write(new_file + '\n')
                return
#-----------------对标签文件进行染色-----------
def colorful(input_path, save_path):
    """Save a label array as a palette ("P"-mode) image.

    Args:
        input_path: the label pixel array (despite the name, not a path). It
            is converted with ``PIL.Image.fromarray``.
        save_path: destination file; it is overwritten.

    Palette indices 0-20 get the 21 colours below, which match the standard
    PASCAL VOC class colour map. Every other index i is shown as grey (i, i, i).
    """
    label_img = Image.fromarray(input_path)
    class_colours = [
        (0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0),
        (0, 0, 128), (128, 0, 128), (0, 128, 128), (128, 128, 128),
        (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0),
        (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128),
        (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0),
        (0, 64, 128),
    ]
    flat_palette = []
    for index in range(256):
        if index < len(class_colours):
            flat_palette.extend(class_colours[index])
        else:
            flat_palette.extend((index, index, index))
    label_img.putpalette(flat_palette)
    label_img.save(save_path)
# -------------------数据增强方法--------------------------
# 数据增强方法一：水平翻转
def Raster_Hor(im_data,segm_img_name,label_read,segm_label_name,im_geotrans, im_proj):
    """Augmentation 1: mirror an image/label pair left-right and save both.

    The flipped image is written with ``saveRaster`` (keeping the given
    geotransform/projection) as ``'hor' + segm_img_name`` in ``img_path``.
    The flipped label is written as ``'hor' + segm_label_name`` in
    ``label_path`` and then re-saved with a colour palette by ``colorful``.

    Note: ``img_path``/``label_path`` are module-level globals that are only
    assigned in the ``__main__`` block, so this function works only when the
    file runs as a script (or the caller sets them).

    Raises:
        OSError: if OpenCV fails to write the flipped label.
    """
    # Flip along the last (column) axis. This handles both (bands, H, W) and
    # single-band (H, W) arrays; the previous axis=2 failed on 2-D input.
    flipped_image = np.flip(im_data, axis=-1)
    image_out = os.path.join(img_path, 'hor' + segm_img_name)
    saveRaster(flipped_image, im_geotrans, im_proj, image_out)

    # flipCode=1 flips around the y-axis, i.e. horizontally.
    flipped_label = cv2.flip(label_read, 1)
    label_out = os.path.join(label_path, 'hor' + segm_label_name)
    if not cv2.imwrite(label_out, flipped_label):
        raise OSError(f"cv2.imwrite failed for {label_out}")
    # Re-read as 8-bit grayscale and overwrite the file with a paletted version.
    label_gray = cv2.imread(label_out, cv2.IMREAD_GRAYSCALE)
    colorful(label_gray, label_out)
# 数据增强方法二：垂直翻转
def Raster_Dia(im_data,segm_img_name,label_read,segm_label_name,im_geotrans, im_proj):
    """Augmentation 2: flip an image/label pair upside-down and save both.

    The flipped image is written with ``saveRaster`` (keeping the given
    geotransform/projection) as ``'dia' + segm_img_name`` in ``img_path``.
    The flipped label is written as ``'dia' + segm_label_name`` in
    ``label_path`` and then re-saved with a colour palette by ``colorful``.

    Note: ``img_path``/``label_path`` are module-level globals that are only
    assigned in the ``__main__`` block, so this function works only when the
    file runs as a script (or the caller sets them).

    Raises:
        OSError: if OpenCV fails to write the flipped label.
    """
    # Flip along the row axis (second from last). This handles both
    # (bands, H, W) and single-band (H, W) arrays; the previous axis=1
    # flipped columns instead of rows for 2-D input.
    flipped_image = np.flip(im_data, axis=-2)
    image_out = os.path.join(img_path, 'dia' + segm_img_name)
    saveRaster(flipped_image, im_geotrans, im_proj, image_out)

    # flipCode=0 flips around the x-axis, i.e. vertically.
    flipped_label = cv2.flip(label_read, 0)
    label_out = os.path.join(label_path, 'dia' + segm_label_name)
    if not cv2.imwrite(label_out, flipped_label):
        raise OSError(f"cv2.imwrite failed for {label_out}")
    # Re-read as 8-bit grayscale and overwrite the file with a paletted version.
    label_gray = cv2.imread(label_out, cv2.IMREAD_GRAYSCALE)
    colorful(label_gray, label_out)
# 数据增强方法三：随机缩放及旋转
def Raster_Ratation(im_data, segm_img_name, label_read, segm_label_name):
    """Augmentation 3: randomly rotate and scale an image/label pair and save both.

    One affine transform is drawn and applied to both image and label so they
    stay aligned:
    - rotation angle: random integer in [15, 359] degrees;
    - scale factor: random value in [0.9, 1.1];
    - centre: the image centre.

    Nearest-neighbour interpolation is used for both, so label values are not
    blended. Uncovered image pixels are filled with (255, 255, 255). Uncovered
    label pixels use OpenCV's default constant border value.

    The rotated image is saved with PIL as ``'rat' + segm_img_name`` in
    ``img_path``, so it carries no georeferencing. The label is written as
    ``'rat' + segm_label_name`` in ``label_path`` and re-saved with a palette.

    Note: ``img_path``/``label_path`` are module-level globals that are only
    assigned in the ``__main__`` block.

    Raises:
        OSError: if OpenCV fails to write the transformed label.
    """
    rotateDeg = random.randint(15, 359)
    rotateScale = random.uniform(0.9, 1.1)
    # GDAL arrays are (bands, H, W); OpenCV/PIL expect (H, W, bands). A
    # single-band (H, W) array is already in the right layout. The previous
    # unconditional transpose(1, 2, 0) crashed on it.
    if im_data.ndim == 3:
        im_data = im_data.transpose(1, 2, 0)
    height, width = im_data.shape[:2]

    M = cv2.getRotationMatrix2D((width // 2, height // 2), rotateDeg, rotateScale)
    image = cv2.warpAffine(im_data, M, (width, height), flags=cv2.INTER_NEAREST,
                           borderValue=(255, 255, 255))
    label = cv2.warpAffine(label_read, M, (width, height), flags=cv2.INTER_NEAREST)

    image_out = os.path.join(img_path, 'rat' + segm_img_name)
    Image.fromarray(image).save(image_out)

    label_out = os.path.join(label_path, 'rat' + segm_label_name)
    if not cv2.imwrite(label_out, label):
        raise OSError(f"cv2.imwrite failed for {label_out}")
    # Re-read as 8-bit grayscale and overwrite the file with a paletted version.
    label_gray = cv2.imread(label_out, cv2.IMREAD_GRAYSCALE)
    colorful(label_gray, label_out)
# -------------------总处理函数--------------------------
def data_enhance(img_path, label_path, name_keywords, train_txt=None, val_txt=None):
    """Apply one augmentation to every image/label pair and register the results.

    Args:
        img_path: directory of input images.
        label_path: directory of labels. Each label is looked up under the
            same file name as its image.
        name_keywords: 'hor' (horizontal flip), 'dia' (vertical flip) or
            'rat' (random rotation + scaling). It is also the prefix of the
            new file names.
        train_txt, val_txt: split list files handed to ``saveTXT``. They
            default to the module-level ``raser_path1`` / ``raster_path2``
            assigned in ``__main__``.

    Nothing is processed when the two directories hold different numbers of
    files or when ``name_keywords`` is unknown; a message is printed instead.

    NOTE(review): the Raster_* helpers write to the *global* ``img_path`` /
    ``label_path``, not to the arguments given here.
    """
    if train_txt is None:
        train_txt = raser_path1
    if val_txt is None:
        val_txt = raster_path2

    segm_img = os.listdir(img_path)
    segm_label = os.listdir(label_path)
    if len(segm_img) != len(segm_label):
        print('图像与标签数据不一致')
        return
    print(f'图像数量为{len(segm_img)},标签数量为{len(segm_label)}')

    # Validate once up front. Previously an unknown keyword was reported per
    # file while bogus names were still appended to the split lists.
    if name_keywords not in ('hor', 'dia', 'rat'):
        print('关键字错误')
        return

    for img_name in tqdm(segm_img):
        _, _, _, im_data, im_geotrans, im_proj = readRaster(
            os.path.join(img_path, img_name))
        # Labels are matched to images by file name, not by listdir position.
        _, _, _, label_read, _, _ = readRaster(os.path.join(label_path, img_name))

        if name_keywords == 'hor':
            Raster_Hor(im_data, img_name, label_read, img_name, im_geotrans, im_proj)
        elif name_keywords == 'dia':
            Raster_Dia(im_data, img_name, label_read, img_name, im_geotrans, im_proj)
        else:  # 'rat'
            Raster_Ratation(im_data, img_name, label_read, img_name)

        # Split lists store names without extension. The new name is derived
        # from the image name, which is what the augmentation functions used
        # for their output files. The old code used segm_label[i], and
        # listdir does not guarantee that it pairs with segm_img[i].
        old_file = os.path.splitext(img_name)[0]
        new_file = os.path.splitext(name_keywords + img_name)[0]
        saveTXT(train_txt, val_txt, old_file, new_file)




if __name__ == "__main__":
    # Root directory of the VOC-style dataset.
    data_path = r'H:\segm\voc2007\voc2007'
    # Image directory (JPEGImages) and label directory (SegmentationClass).
    # These module-level names are also read as globals by the Raster_*
    # augmentation functions, so they must keep these exact names.
    img_path, label_path = [os.path.join(data_path, sub_dir)
                            for sub_dir in ('JPEGImages', 'SegmentationClass')]
    # Train/val split lists that new sample names get appended to.
    split_prefix = data_path + '/'
    raser_path1 = split_prefix + 'ImageSets/Segmentation/train.txt'
    raster_path2 = split_prefix + 'ImageSets/Segmentation/val.txt'

    # Augmentation keyword: 'hor' = horizontal flip, 'dia' = vertical flip,
    # 'rat' = random rotation + scaling. Run one method at a time; when it
    # finishes, change the keyword and run again.
    data_enhance(img_path, label_path, 'rat')



